from typing import Type, Union

from torch import nn

# Registry of supported activations: lower-case name -> ``torch.nn`` module
# class (the class itself, not an instance). ``get_activation_class``
# lower-cases user-supplied names before indexing, so every key here must be
# lower-case.
_activations = dict(
    relu=nn.ReLU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    silu=nn.SiLU,
    leakyrelu=nn.LeakyReLU,
    prelu=nn.PReLU,
)


def get_activation_class(
    activation: Union[str, Type[nn.Module], nn.Module],
) -> Union[Type[nn.Module], nn.Module]:
    """Resolve an activation specification to a ``torch.nn`` activation.

    Args:
        activation: One of
            * a case-insensitive name registered in ``_activations``
              (e.g. ``"relu"``, ``"LeakyReLU"``) -> the matching class;
            * an ``nn.Module`` subclass (e.g. ``nn.GELU``) -> returned as is;
            * an ``nn.Module`` instance (e.g. ``nn.ReLU()``) -> returned as is.

    Returns:
        The activation class for names and classes. When an instance is
        passed, that same instance is returned (not its class).

    Raises:
        KeyError: ``activation`` is a string that is not a registered name.
        ValueError: ``activation`` is none of the supported kinds.
    """
    if isinstance(activation, nn.Module):
        # Already-constructed instances are passed through untouched.
        return activation
    if isinstance(activation, type) and issubclass(activation, nn.Module):
        # A module class is already the thing this function resolves to.
        return activation
    if isinstance(activation, str):
        name = activation.lower()
        try:
            return _activations[name]
        except KeyError:
            # Keep KeyError so existing callers that catch it still work,
            # but tell the user which names are valid.
            raise KeyError(
                f"Unknown activation {activation!r}; expected one of "
                f"{sorted(_activations)}"
            ) from None
    raise ValueError(f"Invalid activation: {activation}")
